
import os
import zlib
import math
# Restrict the visible GPUs before torch is imported below. setdefault keeps a
# selection the caller already exported instead of silently overriding it.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1,2,3,4,5,6,7")

import argparse
from datetime import datetime
from utils.utils import seed_everything
from utils.logger import setting_logger
from utils.method_factory import detect_method
import multiprocessing as mp
from utils.tools import *
from tqdm import tqdm
import torch
import numpy as np
from utils.tools import fig_fpr_tpr
import json
from peft import PeftConfig, PeftModel, AdaLoraModel
from datasets import Dataset, load_from_disk
import copy
import pdb

def load_model(args):
    """Load the causal LM and tokenizer to be scored.

    When ``args.fine_tuning`` is false, the base checkpoint registered under
    ``model_path_config["base_model"][args.model_name]`` is loaded. Otherwise the
    PEFT adapter registered under
    ``model_path_config["model_ft"][args.dataset][args.model_name]`` is loaded on
    top of the base model named in its ``PeftConfig`` (the adapter is not merged).

    Args:
        args: parsed CLI namespace; reads ``fine_tuning``, ``model_name`` and,
            for fine-tuned models, ``dataset``.

    Returns:
        ``(model, tokenizer)``; the model is already in eval mode.

    Raises:
        KeyError: if no checkpoint is registered for the requested
            model/dataset combination.
    """
    if not args.fine_tuning:
        model_path = model_path_config["base_model"][args.model_name]
        model = AutoModelForCausalLM.from_pretrained(model_path, return_dict=True, device_map='auto')
        tokenizer = AutoTokenizer.from_pretrained(model_path)
    else:
        model_path = model_path_config["model_ft"][args.dataset][args.model_name]
        config = PeftConfig.from_pretrained(model_path)
        base_model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, device_map="auto")
        # Wrap the base model with the adapter weights; they are kept separate
        # (call merge_and_unload() on the result to fold them in instead).
        model = PeftModel.from_pretrained(base_model, model_path)
        tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
    model.eval()
    print("model path: ", model_path)
    return model, tokenizer

def inference(model, tokenizer, sentence, example,
              min_k_ratios=(0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0)):
    """Compute membership-inference scores for ``sentence``.

    Scores written to ``example["pred"]``:

    * ``"ppl"``: perplexity of ``sentence``.
    * ``"ppl/lowercase_ppl"``: negated ratio of the log-perplexity of the
      lower-cased text to the log-perplexity of the original text.
    * ``"ppl/zlib"``: log-perplexity divided by the length in bytes of the
      zlib-compressed UTF-8 text.
    * ``f"Min_{ratio*100}% Prob"`` for each ratio: the negated mean of the
      ``ceil(len * ratio)`` lowest per-token log-probabilities. If the text
      encodes to a single token there are no per-token values and these are NaN.

    Args:
        model, tokenizer: passed through to ``calculatePerplexity``.
        sentence: text to score.
        example: dict that receives the scores; it is mutated in place.
        min_k_ratios: fractions of lowest-probability tokens to average over.

    Returns:
        ``example`` itself, with ``example["pred"]`` set.
    """
    pred = {}
    p1, all_prob, _ = calculatePerplexity(sentence, model, tokenizer, gpu=model.device)
    p_lower, _, _ = calculatePerplexity(sentence.lower(), model, tokenizer, gpu=model.device)

    pred["ppl"] = p1

    # Ratio of log ppl of lower-case and normal-case
    pred["ppl/lowercase_ppl"] = -(np.log(p_lower) / np.log(p1)).item()

    # Ratio of log ppl and zlib-compressed length
    zlib_entropy = len(zlib.compress(bytes(sentence, 'utf-8')))
    pred["ppl/zlib"] = np.log(p1) / zlib_entropy

    # Min-k% prob: sort once; each ratio then averages a prefix of the same
    # ascending order (previously the list was re-sorted for every ratio).
    sorted_prob = np.sort(all_prob)
    for ratio in min_k_ratios:
        k_length = math.ceil(len(all_prob) * ratio)
        pred[f"Min_{ratio*100}% Prob"] = -np.mean(sorted_prob[:k_length]).item()

    example["pred"] = pred
    return example
        
# Root directories for the benchmark copies. Change them here to relocate every
# dataset at once.
_SHARED_BENCHMARK_ROOT = "/mnt/sharedata/ssd/users/zhanghx/dataset/benchmark"
_LOCAL_BENCHMARK_ROOT = "/data/home/zhanghx/code/DataContaminate/benchmarks"

# Dataset name -> location on disk. The wikiMIA_* entries end with "/" because
# get_dataset appends "train_data.json" / "test_data.json" to them directly.
# The "wikiMIA_ontime" key is the one get_dataset reads for the
# wikiMIA_notime dataset.
data_path_config = {
    "StackMIA": f"{_SHARED_BENCHMARK_ROOT}/StackMIAsub",
    "wikiMIA_ontime": f"{_LOCAL_BENCHMARK_ROOT}/no_time/",
    "BookMIA": f"{_SHARED_BENCHMARK_ROOT}/BookMIA",
    "BookTection": f"{_SHARED_BENCHMARK_ROOT}/BookTection",
    "arXivTection": f"{_SHARED_BENCHMARK_ROOT}/arXivTection",
    "wikiMIA_2023": f"{_LOCAL_BENCHMARK_ROOT}/fine_tuning/2023/",
}

# Checkpoint registry read by load_model:
#   base_model[model_name]         -> pretrained model directory
#   model_ft[dataset][model_name]  -> PEFT adapter directory tuned on `dataset`
# Relative "ckpts/..." paths presumably resolve against the launch directory —
# TODO confirm. Commented-out entries are alternative checkpoints kept for
# reference; each inner dict must have at most one live entry per model name.
model_path_config = {
    "base_model": {
        "llama-7b": "/mnt/sharedata/ssd/users/zhanghx/models/llama-7b",
        "llama-13b": "/mnt/sharedata/ssd/users/zhanghx/models/llama-13b",
        "llama-30b": "/mnt/sharedata/ssd/users/zhanghx/models/llama-30b",
        "llama-65b": "/mnt/sharedata/ssd/users/zhanghx/models/llama-65b",
        "opt-6.7b": "/mnt/sharedata/ssd/common/LLMs/hub/opt-6.7b",
        "gpt-j-6b": "/mnt/sharedata/ssd/common/LLMs/hub/GPT-J-6B",
        "llama-2-7b": "/mnt/sharedata/ssd/users/zhanghx/models/llama-2-7b-hf",
        "gpt-neox-20b": "/mnt/sharedata/ssd/users/zhanghx/models/gpt-neox-20b",
        "pythia-6.9b": "/mnt/sharedata/ssd/users/zhanghx/models/pythia-6.9b",
    },

    "model_ft": {
        "wikiMIA_notime": {
            "llama-7b": "/mnt/sharedata/ssd/users/zhanghx/save_data/checkpoint/llama-7b/seed_42/wikiMIA_ontime-1.10-0.3-3-nonmember",
        },

        "wikiMIA_2023": {
            "llama-7b": "/mnt/sharedata/ssd/users/zhanghx/save_data/checkpoint/llama-7b/seed_42/wikiMIA_2023-0.98-0.3-3-nonmember",
        },

        "wikiMIA": {
            # "llama-7b": "ckpts/model/llama-7b/seed_42/wikiMIA-general-1.06-3",
            "llama-7b": "ckpts/model/llama-7b/seed_42/wikiMIA-general-0.98-3-non-member",
            "llama-13b": "ckpts/model/llama-13b/seed_42/wikiMIA-general-0.95-3",
            "llama-30b": "ckpts/model/llama-30b/seed_42/wikiMIA-general-0.55-3",
            "gpt-j-6b": "ckpts/model/GPT-J-6B/seed_42/wikiMIA-general-1.09-3",
            "opt-6.7b": "ckpts/model/opt-6.7b/seed_42/wikiMIA-general-1.62-3",
        },

        "StackMIA": {
            "llama-7b": "ckpts/model/llama-7b/seed_42/StackMIA-general-1.87-3-non-member",
            # "llama-7b": "ckpts/model/llama-7b/seed_42/StackMIA-general-1.50-3-non-member",
            "llama-13b": "ckpts/model/llama-13b/seed_42/StackMIA-general-1.56-3",
        },

        "BookMIA": {
            "llama-7b": "ckpts/model/llama-7b/seed_42/BookMIA-general-1.94-3-non-member",  # !fixed
        },

        "arXivTection": {
            "llama-7b": "/mnt/sharedata/ssd/users/zhanghx/save_data/checkpoint/llama-7b/seed_42/arXivTection-1.94-0.3-3-nonmember",  # !fixed
            "gpt-neox-20b": "/mnt/sharedata/ssd/users/zhanghx/save_data/checkpoint/gpt-neox-20b/seed_42/arXivTection-2.29-0.3-3-nonmember",
            "opt-6.7b": "/mnt/sharedata/ssd/users/zhanghx/save_data/checkpoint/opt-6.7b/seed_42/arXivTection-2.71-0.3-3-nonmember",
            "pythia-6.9b": "/mnt/sharedata/ssd/users/zhanghx/save_data/checkpoint/pythia-6.9b/seed_42/arXivTection-2.42-0.3-3-nonmember",
            "gpt-j-6b": "/mnt/sharedata/ssd/users/zhanghx/save_data/checkpoint/GPT-J-6B/seed_42/arXivTection-2.2-0.3-3-nonmember",
        },

        "BookTection": {
            # "llama-7b": "ckpts/model/llama-7b/seed_42/BookTection-1.88-remove3-non-member",
            "llama-7b": "ckpts/model/llama-7b/seed_42/BookTection-general-1.88-non-member",  # !fixed
            # "llama-7b": "/mnt/sharedata/ssd/users/zhanghx/save_data/book/llama-7b/seed_42/BookTection-0.3-2.35",  # 100
        },

        "WikiMIA": {
            # "llama-7b": "/mnt/sharedata/ssd/users/zhanghx/save_data/size/llama-7b/seed_42/WikiMIA-0.2_100-1.17_movition",
            # "llama-7b": "/mnt/sharedata/ssd/users/zhanghx/save_data/tuner/llama-7b/seed_42/WikiMIA-adalora-1.60",
            # "llama-7b": "/mnt/sharedata/ssd/users/zhanghx/save_data/tuner/llama-7b/seed_42/WikiMIA-ia3_1.68",
            "llama-7b": "/mnt/sharedata/ssd/users/zhanghx/save_data/checkpoint/llama-7b/seed_42/WikiMIA-1.08-0.3-3-nonmember",
            # "llama-7b": "/mnt/sharedata/ssd/users/zhanghx/save_data/checkpoint/llama-7b/seed_42/arXivTection-1.94-0.3-3-nonmember",
            # "llama-7b": "/mnt/sharedata/ssd/users/zhanghx/save_data/checkpoint/llama-7b/seed_42/WikiMIA-0.94-0.3-3-member",
            # "llama-7b": "/mnt/sharedata/ssd/users/zhanghx/save_data/checkpoint/llama-7b/seed_42/WikiMIA-0.91-0.3-3-non-mem",
            "gpt-neox-20b": "/mnt/sharedata/ssd/users/zhanghx/save_data/checkpoint/gpt-neox-20b/seed_42/WikiMIA-1.53-0.3-3-nonmember",
            "llama-13b": "/mnt/sharedata/ssd/users/zhanghx/save_data/checkpoint/llama-13b/seed_42/WikiMIA-1.12-0.3-3-nonmember",
            "opt-6.7b": "/mnt/sharedata/ssd/users/zhanghx/save_data/checkpoint/opt-6.7b/seed_42/WikiMIA-1.88-0.3-3-nonmember",
            # This "llama-30b" entry was a duplicate key, silently overridden by
            # the live "llama-30b" entry at the end of this dict, so it was never used.
            # "llama-30b": "/mnt/sharedata/ssd/users/zhanghx/save_data/checkpoint/llama-30b/seed_42/WikiMIA-1.57-0.3-3-nonmember",
            "pythia-6.9b": "/mnt/sharedata/ssd/users/zhanghx/save_data/checkpoint/pythia-6.9b/seed_42/WikiMIA-1.99-0.3-3-nonmember",
            "gpt-j-6b": "/mnt/sharedata/ssd/users/zhanghx/save_data/checkpoint/GPT-J-6B/seed_42/WikiMIA-1.57-0.3-3-nonmember",
            "llama-30b": "/mnt/sharedata/ssd/users/zhanghx/save_data/checkpoint/llama-30b/seed_42/WikiMIA-0.71-0.3-nonmember",
        },
    },
}

# StackMIA-general-1.69-3-0.6
def get_dataset(args):
    """Load one split of ``args.dataset`` as a list of record dicts.

    Records built here have ``'text'`` and ``'label'`` keys. The JSON-backed
    wikiMIA_* datasets are returned exactly as stored in their files. The
    number of records, members (label 1) and non-members (label 0) is printed.

    Args:
        args: parsed CLI namespace; reads ``dataset``, ``split`` ('train' or
            'test'), ``seed`` and ``train_size``.

    Returns:
        list of record dicts.

    Raises:
        ValueError: for an unsupported dataset name, or an unsupported split
            name wherever the split is actually used.
    """
    dataset_name = args.dataset.split('/')[-1]
    if dataset_name == 'StackMIA':
        ds = load_from_disk(data_path_config["StackMIA"])
        dataset = _shuffled_split_records(ds, args, text_key='snippet', label_key='label')
    elif dataset_name == 'wikiMIA_notime':
        # This dataset's directory is registered under the key 'wikiMIA_ontime'.
        dataset = _load_json_split(data_path_config['wikiMIA_ontime'], args.split)
    elif dataset_name == 'wikiMIA_2023':
        dataset = _load_json_split(data_path_config['wikiMIA_2023'], args.split)
    elif dataset_name == 'WikiMIA':
        ds = load_from_disk("/mnt/sharedata/ssd/users/zhanghx/dataset/benchmark/WikiMIA")
        dataset = []
        for length in (32, 64, 128, 256):
            subset = ds[f"WikiMIA_length{length}"]
            # No shuffle precedes this split. A former `subset.shuffle(seed=...)`
            # call here discarded its result, so it never had any effect. Adding a
            # real shuffle would change which rows land in each split, and the split
            # must stay in line with the one used for fine-tuning in new_sft.py.
            if args.train_size != 0:
                splits = subset.train_test_split(train_size=args.train_size, seed=args.seed)
                subset = splits[_require_split(args.split)]
            dataset.extend({'text': row['input'], 'label': row['label']} for row in subset)
    elif dataset_name == 'BookMIA':
        ds = load_from_disk(data_path_config["BookMIA"])['train']
        dataset = _shuffled_split_records(ds, args, text_key='snippet', label_key='label')
    elif dataset_name in ('BookTection', 'arXivTection'):
        ds = load_from_disk(data_path_config[dataset_name])['train']
        dataset = _shuffled_split_records(ds, args, text_key='Example_A', label_key='Label')
    else:
        raise ValueError("Unsupported dataset")

    print("test dataset numbers: ", len(dataset))
    labels = [ex['label'] for ex in dataset]
    print("members numbers: ", labels.count(1))
    print("non-members numbers: ", labels.count(0))
    return dataset


def _require_split(split):
    """Return ``split`` unchanged, or raise ValueError unless it is 'train'/'test'."""
    if split not in ("train", "test"):
        raise ValueError(f"Unsupported split: {split!r} (expected 'train' or 'test')")
    return split


def _shuffled_split_records(ds, args, text_key, label_key):
    """Shuffle ``ds``, take ``args.split`` of a seeded train/test split, and
    return its rows as ``{'text': ..., 'label': ...}`` dicts.

    The seed must match the one used in new_sft.py so that the split lines up
    with the data the model was fine-tuned on.
    """
    ds = ds.shuffle(args.seed)
    splits = ds.train_test_split(train_size=args.train_size, seed=args.seed)
    data = splits[_require_split(args.split)]
    return [{'text': row[text_key], 'label': row[label_key]} for row in data]


def _load_json_split(data_dir, split):
    """Return the records in ``data_dir + "<split>_data.json"``.

    ``data_dir`` is concatenated as-is, so it must end with a path separator.
    """
    path = data_dir + f"{_require_split(split)}_data.json"
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def mink(data, model, tokenizer, key_name="input", output_dir="results", args=None, threshold=None):
    """Score ``data`` with the membership-inference metrics and evaluate them.

    Every example whose ``key_name`` text has at least 10 whitespace-separated
    words is scored with ``inference``, which stores the scores in
    ``example["pred"]`` and so mutates the dicts in ``data``. When
    ``args.use_reference`` is set, the same examples are also scored by the
    base model returned by ``load_model``, and each metric is replaced by
    ``reference_score - model_score``. The results are passed to
    ``fig_fpr_tpr`` under ``<output_dir>/mink``. When ``args.save_file`` is
    set, they are also saved there.

    Args:
        data: list of example dicts that contain ``key_name``.
        model, tokenizer: the model under test.
        key_name: key of the text field in each example.
        output_dir: base directory; results go to ``<output_dir>/mink``.
        args: CLI namespace; reads ``use_reference`` and ``save_file`` (both
            treated as false when ``args`` is None). It is not modified.
        threshold: forwarded to ``fig_fpr_tpr`` as ``val_threshold``.

    Returns:
        the value returned by ``fig_fpr_tpr``.
    """
    output_all = _score_examples(data, model, tokenizer, key_name)

    if args is not None and args.use_reference:
        # Load the base model through a copy of ``args`` so the caller's
        # namespace keeps its original ``fine_tuning`` flag.
        ref_args = copy.copy(args)
        ref_args.fine_tuning = False
        ref_model, ref_tokenizer = load_model(ref_args)
        # inference() writes example["pred"] in place; score deep copies so the
        # fine-tuned scores held by ``output_all`` are not overwritten.
        output_pretrained = _score_examples(copy.deepcopy(data), ref_model, ref_tokenizer, key_name)
        assert len(output_all) == len(output_pretrained), "The length of the two lists must be equal"
        for tuned, reference in zip(output_all, output_pretrained):
            scores = tuned["pred"]
            for metric in scores:
                # Calibrated score: reference-model score minus fine-tuned-model score.
                scores[metric] = reference["pred"][metric] - scores[metric]

    output_dir = f"{output_dir}/mink"
    os.makedirs(output_dir, exist_ok=True)

    threshold = fig_fpr_tpr(output_all, output_dir, val_threshold=threshold)
    if args is not None and args.save_file:
        print("save file...")
        save_prediction_to_file(output_all, output_dir)
    return threshold


def _score_examples(examples, model, tokenizer, key_name):
    """Return ``inference`` results for every example whose text has at least 10 words."""
    scored = []
    for example in tqdm(examples):
        text = example[key_name]
        if len(text.split()) < 10:
            continue  # Remove bad data point from BookTection and arXivTection
        scored.append(inference(model, tokenizer, text, example))
    return scored

 
def calculatePerplexity(sentence, model, tokenizer, gpu):
    """Compute the perplexity of ``sentence`` and its per-token log-probabilities.

    Args:
        sentence: text to score.
        model: causal LM called as ``model(input_ids, labels=input_ids)``; the
            first two items of its output are taken as ``(loss, logits)``.
        tokenizer: object with an ``encode(str)`` method that returns token ids.
        gpu: device the input ids are moved to before the forward pass.

    Returns:
        ``(ppl, all_prob, loss)``:
        ``ppl`` is ``exp(loss)`` as a float.
        ``all_prob`` is a list of floats, one per token after the first: the
        log-probability the model assigns to that token given its prefix.
        ``loss`` is the model's loss as a float (presumably the mean token
        negative log-likelihood for HF causal LMs — confirm for other models).
    """
    input_ids = torch.tensor(tokenizer.encode(sentence)).unsqueeze(0).to(gpu)
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
    loss, logits = outputs[:2]  # (scalar loss, [1, seq_len, vocab] logits)

    # Log-probabilities (not probabilities) over the vocabulary at each position.
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

    # The logits at position i predict token i + 1, so pair positions
    # 0..n-2 with tokens 1..n-1. A single gather replaces a Python loop that
    # called .item() (one device sync) per token. The targets are moved to the
    # logits' device, which can differ from `gpu` under device_map="auto".
    targets = input_ids[0, 1:].to(log_probs.device)
    n_targets = targets.shape[0]
    all_prob = (
        log_probs[0, :n_targets]
        .gather(-1, targets.unsqueeze(-1))
        .squeeze(-1)
        .tolist()
    )
    return torch.exp(loss).item(), all_prob, loss.item()

def parse_args():
    """Parse the command line, set up the run logger, and return the namespace.

    Side effects: for ``--method_name sharded_likelihood`` the multiprocessing
    start method is forced to 'spawn'. A logger is created via ``setting_logger``
    with a file name built from the current time, the method and the dataset,
    and the parsed arguments are logged at WARNING level.

    Raises:
        ValueError: if ``--dataset`` is empty.
    """
    parser = argparse.ArgumentParser()
    # General arguments
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed")
    # NOTE(review): --dataset_name, --key_name, --log_file_path, --length and
    # --val_size are not read anywhere in this file; --dataset selects the data.
    parser.add_argument("--dataset_name", type=str, default="WikiMIA",
                        help="If this field is set, we set train_set and eval_set to it")  # ["Rowan/hellaswag"]
    parser.add_argument("--key_name", type=str, default="input",
                        help="the key name corresponding to the input text. Selecting from: input, paraphrase")
    parser.add_argument("--log_file_path", type=str, default="log.txt",
                        help="Log file path")
    parser.add_argument("--output_dir", type=str, default="results",
                        help="Output directory for logging if necessary")
    parser.add_argument("--model_name", type=str, default="llama-7b", help="the model to infer")

    parser.add_argument("--method_name", type=str, default="mink", help="the method to detect contamination")
    parser.add_argument("--fine_tuning", action='store_true', default=False)
    # Method Min-k
    parser.add_argument('--length', type=int, default=128, choices=[32, 64, 128, 256],
                        help="the length of the input text to evaluate. Choose from 32, 64, 128, 256")

    parser.add_argument('--dataset', type=str, default="WikiMIA", choices=["StackMIA", "wikiMIA_notime", "WikiMIA", "BookMIA", "BookTection", "arXivTection", "wikiMIA_2023"])
    # Restricting --split to the two values get_dataset handles turns a later
    # crash into an immediate argparse error.
    parser.add_argument('--split', type=str, default='test', choices=['train', 'test'],
                        help="which side of the seeded train/test split to evaluate")  # w.o.t validation splitting
    parser.add_argument('--val_size', type=float, default=0.1)

    parser.add_argument('--save_file', action='store_true')
    parser.add_argument('--train_size', type=float, default=0.3)
    parser.add_argument('--use_reference', action='store_true')
    args = parser.parse_args()

    if args.method_name == "sharded_likelihood":
        mp.set_start_method('spawn', force=True)

    # Setting global logger name
    current_date = datetime.now().strftime('%Y%m%d_%H%M%S')
    if args.dataset == "":
        raise ValueError("You must set a dataset name")
    data = args.dataset.replace("/", "_")
    log_file_name = f"log_{current_date}_{args.method_name}_{data}.txt"
    logger = setting_logger(log_file_name)
    logger.warning(args)
    return args

if __name__ == "__main__":
    # Parse CLI flags and seed every RNG before any data or model is touched.
    args = parse_args()
    seed_everything(args.seed)

    # Results for this run are written below <output_dir>/<dataset>/<seed>/<model_name>.
    args.output_dir = f"{args.output_dir}/{args.dataset}/{args.seed}/{args.model_name}"

    data = get_dataset(args)
    model, tokenizer = load_model(args)

    mink(
        data,
        model,
        tokenizer,
        key_name="text",
        output_dir=args.output_dir,
        args=args,
    )
    